
# ---- What to run ----
# Names the section to execute; speechTrain is the section defined further down in this file.
command = speechTrain

# ---- Numerics and hardware ----
precision = float
# The device id is not fixed here; it is filled in through the DeviceId substitution variable.
deviceId = $DeviceId$
parallelTrain = false

# ---- Sequence handling ----
# NOTE(review): both switches are off, presumably so that whole utterances are
# processed as sequences for the recurrent network below - confirm against the reader docs.
truncated = false
frameMode = false

# Training section, referenced by the top-level command setting.
# It configures the SGD optimizer, an HTK feature/label reader, and a
# BrainScript network: a stack of projected LSTM layers with self-stabilization
# trained with the ForwardBackward criterion and evaluated with EditDistanceError.
speechTrain = [
    action = "train"
    
    traceLevel = 1
    # Where the model is written; RunDir is a substitution variable supplied from outside this file.
    modelPath = $RunDir$/models/simple.dnn
    
    # Optimizer settings.
    SGD = [
        # NOTE(review): epochSize is presumably counted in samples rather than sequences - confirm against the SGD block docs.
        epochSize = 250
        minibatchSize = 20
        learningRatesPerMB = 0.1
        maxEpochs = 10
    ]
    
    # Data source. Two deserializers are combined: one for features, one for labels.
    # Randomization is switched off.
    reader = [
        verbosity = 0
        randomize = false
        
        # A list of deserializers the reader uses.
        deserializers = (
            [
                type = "HTKFeatureDeserializer"
                module = "HTKDeserializers"
                input = [
                    # Description of input stream to feed the Input node named "features"
                    features = [
                        # 363 equals featDim in the network below, which is 11 * baseFeatDim with baseFeatDim of 33.
                        dim=363
                        scpFile = "$DataDir$/glob_0000.scp"
                    ]
                ]
            ]:
            [
                type = "HTKMLFDeserializer"
                module = "HTKDeserializers"
                input = [
                    # Description of input stream to feed the Input node named "labels"
                    labels = [
                        # 133 equals labelDim in the network below.
                        dim = 133
                        mlfFile="$DataDir$/glob_0000.mlf"
                        labelMappingFile = "$DataDir$/state_ctc.list" 
                        labelType=Category
                        # NOTE(review): presumably makes the deserializer emit phone boundary information
                        # that LabelsToGraph consumes - confirm against the HTKMLFDeserializer docs.
                        phoneBoundaries=true
                    ]
                ]
            ]
        )
    ]

    # define network using BrainScript
    BrainScriptNetworkBuilder = {
        
        // Parameter helpers.
        // WeightParam: m-by-n parameter, uniform init with scale 1, initialized on CPU with fixed seed 1.
        // BiasParam:   m-by-1 parameter initialized to zero.
        // ScalarParam: 1-by-1 parameter initialized to zero.
        WeightParam(m,n) = Parameter(m, n, init="uniform", initValueScale=1, initOnCPUOnly=true, randomSeed=1)
        BiasParam(m) = Parameter(m, 1, init="fixedValue", value=0.0)
        ScalarParam() = Parameter(1, 1, init="fixedValue", value=0.0)

        // Self-stabilizer: scales its input by Exp of a scalar parameter.
        // The scalar starts at zero, so the initial scale factor is Exp of 0, which is 1.
        NewBeta() = Exp(ScalarParam())
        Stabilize(in) = Scale(NewBeta(), in)

        // One LSTM layer with a projection of the cell output and self-stabilized inputs.
        //   inputDim  - dimension of inputx, used as column count of the input weights
        //   outputDim - dimension of the projected output, which is also the recurrent input
        //   cellDim   - dimension of the gates and of the cell state
        //   inputx    - the input to this layer
        // The result is a record; callers below read its output member.
        LSTMPComponentWithSelfStab(inputDim, outputDim, cellDim, inputx) =
        [
            // parameter macros--these carry their own weight matrices
            B() = BiasParam(cellDim)
            // projection matrix from cell dimension down to output dimension, used once at the end
            Wmr = WeightParam(outputDim, cellDim);

            W(v) = WeightParam(cellDim, inputDim) * Stabilize(v)    // input-to-hidden
            H(h) = WeightParam(cellDim, outputDim) * Stabilize(h)   // hidden-to-hidden
            C(c) = DiagTimes(WeightParam(cellDim, 1), Stabilize(c)) // cell-to-hidden

            // LSTM cell
            dh = PastValue(outputDim, output);                   // hidden state(t-1)
            dc = PastValue(cellDim, ct);                         // cell(t-1)

            // note: the W(inputx) here are all different, they all come with their own set of weights; same for H(dh), C(dc), and B()
            it = Sigmoid(W(inputx) + B() + H(dh) + C(dc))       // input gate(t)
            bit = it .* Tanh(W(inputx) + (H(dh) + B()))         // applied to tanh of input network

            ft = Sigmoid(W(inputx) + B() + H(dh) + C(dc))       // forget-me-not gate(t)
            bft = ft .* dc                                          // applied to cell(t-1)

            ct = bft + bit                                          // c(t) is sum of both

            // unlike the input and forget gates, the output gate looks at the current cell ct rather than the delayed dc
            ot = Sigmoid(W(inputx) + B() + H(dh) + C(ct))       // output gate(t)
            mt = ot .* Tanh(ct)                                     // applied to tanh(cell(t))

            output = Wmr * Stabilize(mt)                            // projection
        ]

        // define basic I/O
        // featDim evaluates to 363 and labelDim is 133; both match the dims declared in the reader section
        baseFeatDim = 33
        featDim = 11 * baseFeatDim
        labelDim = 133

        // hidden dimensions
        // hiddenDim is passed as outputDim of every LSTM layer, cellDim as its cell dimension
        cellDim = 1024
        hiddenDim = 256
        numLSTMs = 1        

        // features
        // the names features and labels are the ones the reader input streams refer to
        features = Input{featDim}
        labels = Input{labelDim}
        // NOTE(review): assuming RowSlice takes start row, row count, input - this keeps only the
        // last baseFeatDim rows of the stacked input, which is why the first LSTM layer uses baseFeatDim as inputDim
        feashift = RowSlice(featDim - baseFeatDim, baseFeatDim, features);      # shift 5 frames right (x_{t+5} -> x_{t} )

        featNorm = MeanVarNorm(feashift)

        // define the stack of hidden LSTM layers
        // layer 1 reads the normalized features; every further layer reads the output of the layer below it
        LSTMoutput[k:1..numLSTMs] = if k == 1
                                    then LSTMPComponentWithSelfStab(baseFeatDim, hiddenDim, cellDim, featNorm)
                                    else LSTMPComponentWithSelfStab(hiddenDim,   hiddenDim, cellDim, LSTMoutput[k-1].output)

        // and add a softmax layer on top
        // note: only the affine part is built here; no Softmax node is applied in this file
        W(in) = WeightParam(labelDim, hiddenDim) * Stabilize(in)
        B = BiasParam(labelDim)
        
        LSTMoutputW = W(LSTMoutput[numLSTMs].output) + B;

        // training
        // NOTE(review): 132 is labelDim - 1; presumably the id of the blank token, passed to the
        // criterion and excluded from the edit distance - confirm against state_ctc.list
        graph = LabelsToGraph(labels)
        cr = ForwardBackward(graph, LSTMoutputW, 132, delayConstraint=3, tag="criterion") 
        Err = EditDistanceError(labels, LSTMoutputW, squashInputs=true, tokensToIgnore = 132, tag="evaluation") 

        // decoding
        // the output node is the top layer activation minus the log prior computed from the labels
        logPrior = LogPrior(labels)
        ScaledLogLikelihood = Minus(LSTMoutputW, logPrior, tag="output")    
    }
]
